
Conversation

a-r-r-o-w
Contributor

What does this PR do?

Adds support for loading SparseCtrl checkpoints with from_single_file.

Code
import torch
from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "guoyww/animatediff-sparsectrl-scribble"
lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
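# local snapshot of v3_sd15_sparsectrl_scribble.ckpt from the guoyww/animatediff Hub repo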
controlnet = SparseControlNetModel.from_single_file("/raid/aryan/hub/models--guoyww--animatediff/snapshots/fdfe36afa161e51b3e9c24022b0e368d59e7345e/v3_sd15_sparsectrl_scribble.ckpt", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png",
]
condition_frame_indices = [0, 8, 15]
conditioning_frames = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    conditioning_frames=conditioning_frames,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "output.gif")
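For reference, from_single_file also accepts a direct URL to a checkpoint on the Hub, so the same model can be loaded without a local snapshot (assuming the file lives in the guoyww/animatediff repo, as the cached path above suggests):

controlnet = SparseControlNetModel.from_single_file(
    "https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_sparsectrl_scribble.ckpt",
    torch_dtype=torch.float16,
)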
Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Contributor Author

@DN6 I currently get the following error when trying to load with single file:

Logs
The config attributes {'addition_embed_type': None, 'addition_embed_type_num_heads': 64, 'addition_time_embed_dim': None, 'class_embed_type': None, 'encoder_hid_dim': None, 'encoder_hid_dim_type': None, 'num_class_embeds': None, 'projection_class_embeddings_input_dim': None} were passed to SparseControlNetModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.26it/s]
Traceback (most recent call last):
  File "/home/aryan/work/diffusers/dump.py", line 137, in <module>
    ).to(device)
  File "/home/aryan/work/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 431, in to
    module.to(device, dtype)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1173, in to
    return self._apply(convert)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 779, in _apply
    module._apply(fn)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 779, in _apply
    module._apply(fn)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 779, in _apply
    module._apply(fn)
  [Previous line repeated 4 more times]
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 804, in _apply
    param_applied = fn(param)
  File "/raid/aryan/diffusers-work-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1166, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

This works when accelerate isn't installed but fails when it is. I'm not really sure how this can be fixed, so I'd appreciate any advice @DN6

@a-r-r-o-w a-r-r-o-w requested a review from DN6 August 3, 2024 22:53
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
Contributor Author

a-r-r-o-w commented Aug 3, 2024

Changed this to attention block 0 because, I think, it still identifies animatediff, and SparseCtrl does not have attention block 1.
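(For anyone following along, a simplified sketch of how this key-based detection works; the names are illustrative and the real logic lives in diffusers' single-file utilities. The loader probes the checkpoint's state dict for a marker key that only one architecture contains, so the marker must exist in every variant of that architecture, SparseCtrl included.)

CHECKPOINT_KEY_NAMES = {
    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
}

def infer_model_type(checkpoint):
    # Illustrative only: return the first architecture whose marker key
    # appears in the checkpoint's state dict.
    for model_type, key in CHECKPOINT_KEY_NAMES.items():
        if key in checkpoint:
            return model_type
    return "unknown"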

@DN6
Collaborator

DN6 commented Aug 5, 2024

@a-r-r-o-w You have required keys missing from the checkpoint. Accelerate loads the weights initially as meta tensors; the ones that don't have a corresponding weight in the checkpoint remain meta tensors and cannot be moved to a device. When you load directly with torch, those weights just stay as randomly initialized tensors.
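For illustration, a minimal standalone repro of that failure mode (plain PyTorch, nothing diffusers-specific):

import torch.nn as nn

# Parameters created on the "meta" device carry shape and dtype but no storage.
layer = nn.Linear(4, 4, device="meta")
try:
    layer.to("cpu")  # cannot materialize a tensor that has no data
except NotImplementedError as err:
    print(err)  # "Cannot copy out of meta tensor; no data!" - same error as in the traceback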

Is the checkpoint conversion function for SparseCtrl the same as the one for the Motion Adapter? Here are the missing keys 👇🏽

down_blocks.0.motion_modules.0.transformer_blocks.0.norm2.weight
down_blocks.0.motion_modules.0.transformer_blocks.0.norm2.bias
down_blocks.0.motion_modules.0.transformer_blocks.0.attn2.to_q.weight
down_blocks.0.motion_modules.0.transformer_blocks.0.attn2.to_k.weight
down_blocks.0.motion_modules.0.transformer_blocks.0.attn2.to_v.weight
down_blocks.0.motion_modules.0.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.0.motion_modules.0.transformer_blocks.0.attn2.to_out.0.bias
down_blocks.0.motion_modules.1.transformer_blocks.0.norm2.weight
down_blocks.0.motion_modules.1.transformer_blocks.0.norm2.bias
down_blocks.0.motion_modules.1.transformer_blocks.0.attn2.to_q.weight
down_blocks.0.motion_modules.1.transformer_blocks.0.attn2.to_k.weight
down_blocks.0.motion_modules.1.transformer_blocks.0.attn2.to_v.weight
down_blocks.0.motion_modules.1.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.0.motion_modules.1.transformer_blocks.0.attn2.to_out.0.bias
down_blocks.1.motion_modules.0.transformer_blocks.0.norm2.weight
down_blocks.1.motion_modules.0.transformer_blocks.0.norm2.bias
down_blocks.1.motion_modules.0.transformer_blocks.0.attn2.to_q.weight
down_blocks.1.motion_modules.0.transformer_blocks.0.attn2.to_k.weight
down_blocks.1.motion_modules.0.transformer_blocks.0.attn2.to_v.weight
down_blocks.1.motion_modules.0.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.1.motion_modules.0.transformer_blocks.0.attn2.to_out.0.bias
down_blocks.1.motion_modules.1.transformer_blocks.0.norm2.weight
down_blocks.1.motion_modules.1.transformer_blocks.0.norm2.bias
down_blocks.1.motion_modules.1.transformer_blocks.0.attn2.to_q.weight
down_blocks.1.motion_modules.1.transformer_blocks.0.attn2.to_k.weight
down_blocks.1.motion_modules.1.transformer_blocks.0.attn2.to_v.weight
down_blocks.1.motion_modules.1.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.1.motion_modules.1.transformer_blocks.0.attn2.to_out.0.bias
down_blocks.2.motion_modules.0.transformer_blocks.0.norm2.weight
down_blocks.2.motion_modules.0.transformer_blocks.0.norm2.bias
down_blocks.2.motion_modules.0.transformer_blocks.0.attn2.to_q.weight
down_blocks.2.motion_modules.0.transformer_blocks.0.attn2.to_k.weight
down_blocks.2.motion_modules.0.transformer_blocks.0.attn2.to_v.weight
down_blocks.2.motion_modules.0.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.2.motion_modules.0.transformer_blocks.0.attn2.to_out.0.bias
down_blocks.2.motion_modules.1.transformer_blocks.0.norm2.weight
down_blocks.2.motion_modules.1.transformer_blocks.0.norm2.bias
down_blocks.2.motion_modules.1.transformer_blocks.0.attn2.to_q.weight
down_blocks.2.motion_modules.1.transformer_blocks.0.attn2.to_k.weight
down_blocks.2.motion_modules.1.transformer_blocks.0.attn2.to_v.weight
down_blocks.2.motion_modules.1.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.2.motion_modules.1.transformer_blocks.0.attn2.to_out.0.bias
down_blocks.3.motion_modules.0.transformer_blocks.0.norm2.weight
down_blocks.3.motion_modules.0.transformer_blocks.0.norm2.bias
down_blocks.3.motion_modules.0.transformer_blocks.0.attn2.to_q.weight
down_blocks.3.motion_modules.0.transformer_blocks.0.attn2.to_k.weight
down_blocks.3.motion_modules.0.transformer_blocks.0.attn2.to_v.weight
down_blocks.3.motion_modules.0.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.3.motion_modules.0.transformer_blocks.0.attn2.to_out.0.bias
down_blocks.3.motion_modules.1.transformer_blocks.0.norm2.weight
down_blocks.3.motion_modules.1.transformer_blocks.0.norm2.bias
down_blocks.3.motion_modules.1.transformer_blocks.0.attn2.to_q.weight
down_blocks.3.motion_modules.1.transformer_blocks.0.attn2.to_k.weight
down_blocks.3.motion_modules.1.transformer_blocks.0.attn2.to_v.weight
down_blocks.3.motion_modules.1.transformer_blocks.0.attn2.to_out.0.weight
down_blocks.3.motion_modules.1.transformer_blocks.0.attn2.to_out.0.bias

Looks like the down block motion-module weights are being missed?

@a-r-r-o-w
Contributor Author

a-r-r-o-w commented Aug 5, 2024

Hmm, interesting. The conversion script mapping dict is pretty much the same as animatediff's (and you can see I've done a strict match in the conversion script).

I'll take a look again

@DN6
Collaborator

DN6 commented Aug 5, 2024

I think temporal_double_self_attention in the motion down blocks got removed when the blocks were refactored. It led to an extra attention layer being created in the model. I've added the argument back in.
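Roughly the idea, as a sketch (hypothetical class and names, not the actual diffusers code): the flag gates whether the second attention block is built at all, which corresponds exactly to the norm2/attn2 keys listed above.

import torch.nn as nn

class TemporalTransformerBlockSketch(nn.Module):
    # Hypothetical sketch: when temporal_double_self_attention is True,
    # the block creates a second attention layer (norm2/attn2). Without
    # a way to pass False, every motion block gets parameters that have
    # no counterpart in the SparseCtrl checkpoint, so they stay as meta
    # tensors under accelerate.
    def __init__(self, dim, num_heads=8, temporal_double_self_attention=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = None
        self.attn2 = None
        if temporal_double_self_attention:
            self.norm2 = nn.LayerNorm(dim)
            self.attn2 = nn.MultiheadAttention(dim, num_heads, batch_first=True)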

@a-r-r-o-w
Contributor Author

Oh... really sorry for the oversight; I did not realize it was needed and had something else in mind when I removed it. I just tested the latest changes and can verify it works as expected. Thanks!

@DN6 DN6 merged commit f6df224 into main Aug 7, 2024
@a-r-r-o-w a-r-r-o-w deleted the animatediff/sparsectrl-singlefile branch August 7, 2024 06:43
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* allow sparsectrl to be loaded with single file

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>